Transfer learning is a machine learning technique that takes knowledge gained from solving a problem in one domain and applies it to another, in order to improve accuracy and reduce training time.
A common application of transfer learning is using the features learned by large convolutional neural networks as information that can be transferred to another domain. Models like VGG16 have learned generalized features from the 1000 classes of ImageNet, and these features can be reused for a more specialized task.

One of the biggest motivations for transfer learning is to make the most of small datasets, for domains and tasks where little data is available.

There are typically two approaches to applying transfer learning in deep learning: using a pretrained network as a fixed feature extractor, and fine-tuning the weights of a pretrained network on the target task.
Applying transfer learning to machine learning originated in a NIPS-95 (Conference on Neural Information Processing Systems, 1995) workshop on "Learning to Learn", which focused on machine-learning methods that retain and reuse previously learned knowledge.
Since then, transfer learning has appeared in different contexts under many different names, such as multitask learning, knowledge transfer, inductive transfer, and domain adaptation.
Multitask learning is closely related, and tries to learn multiple tasks at once. We will cover it later in the semester.
In 2005, DARPA sought to apply transfer learning with the intent of extracting the knowledge learned on a source task and applying it to a target task, without regard for performance on the source task. This differs from multitask learning, in which all tasks are equally important.
Today with the abundance of pretrained models available, transfer learning is very prominent in the fields of machine learning and data mining.
A domain $\mathcal{D}$ consists of a feature space $\mathcal{X}$ and marginal probability distribution $P(X)$ over the feature space, where $X = x_1, ..., x_n \in \mathcal{X}$.
Given a domain, $\mathcal{D} = \left\{\mathcal{X}, P(X)\right\}$, a task $\mathcal{T}$ consists of a label space $\mathcal{Y}$ and a conditional probability distribution $P(Y|X)$ learned from the training data.
Given a source domain $\mathcal{D}_S$, a corresponding source task $\mathcal{T}_S$, as well as a target domain $\mathcal{D}_T$ and a target task $\mathcal{T}_T$, the objective of transfer learning is to learn the target conditional probability distribution $P(Y_T|X_T)$ in $\mathcal{D_T}$ with information gained from $\mathcal{D}_S$ and $\mathcal{D}_T$ where $\mathcal{D}_S \neq \mathcal{D}_T$ or $\mathcal{T}_S \neq \mathcal{T}_T$.
Early approaches defined three different settings within transfer learning: inductive transfer learning (the target task differs from the source task), transductive transfer learning (the domains differ but the task is the same), and unsupervised transfer learning (no labeled data in either domain).
Along with these settings are four approaches that can be used within them: instance transfer, feature-representation transfer, parameter transfer, and relational-knowledge transfer.
The table below shows how these approaches can be used in different settings:
(Sinno Jialin Pan and Qiang Yang, A Survey on Transfer Learning)
A transfer learning example in which the domains are the same, but tasks are different, would be a classifier to detect spam in email. A model could be trained on emails from multiple users. A new email user could then use this model to filter their own messages.
An example in which domains differ but tasks are the same would be a classifier to detect bicycles, but the target domain has very little data (bikes in the wild), and the source domain has a lot of data (bikes in the lab).
(Sun, B., Feng, J., & Saenko, K. (2016). Return of Frustratingly Easy Domain Adaptation)
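The CORAL method cited above handles exactly this different-domain, same-task setting by matching second-order statistics: whiten the source features, then re-color them with the target covariance. Below is a minimal NumPy sketch of that idea; the function name, the `eps` regularizer, and the eigendecomposition route are my own choices, not taken from the paper's code.

```python
import numpy as np

def coral(Xs, Xt, eps=1e-6):
    """Align source features Xs to target features Xt (CORAL-style).

    Whitens Xs with the inverse square root of its covariance, then
    re-colors it with the square root of the target covariance.
    """
    def sym_sqrt(M, inv=False):
        # (inverse) matrix square root of a symmetric matrix via eigendecomposition
        vals, vecs = np.linalg.eigh(M)
        vals = 1.0 / np.sqrt(vals) if inv else np.sqrt(vals)
        return (vecs * vals) @ vecs.T

    d = Xs.shape[1]
    Cs = np.cov(Xs, rowvar=False) + eps * np.eye(d)  # regularize for stability
    Ct = np.cov(Xt, rowvar=False) + eps * np.eye(d)
    return Xs @ sym_sqrt(Cs, inv=True) @ sym_sqrt(Ct)
```

After alignment the covariance of the transformed source features matches the target covariance (up to the regularizer), so a classifier trained on the aligned source data transfers better to the target domain.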
Pre-trained CNN features are often used as a form of transfer learning in a variety of tasks such as image classification, object detection, style transfer, and generative models. Large CNNs trained on ImageNet learn the underlying structure of images, which is knowledge that can be transferred. A lot of recent research in computer vision and deep learning tends to use VGG16 or VGG19 as a base to extract features from the target domain.
YOLO, a state-of-the-art object detector, pretrains its network on ImageNet to learn from a classification task before training the model for object detection.
Johnson et al.'s real-time style transfer uses features extracted from VGG19 in the loss function used to train the style transformation network; the features serve as statistics measuring the style and content of images.
Another application of VGG19 features in a loss function is super resolution: the loss is the difference between the extracted features of the output image and those of the true high-resolution image.
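Both of these applications use the same feature-space ("perceptual") loss: compare images in the feature space of a pretrained network rather than in pixel space. A minimal sketch, with `phi` standing in as a hypothetical placeholder for the VGG feature extractor:

```python
import numpy as np

def perceptual_loss(phi, y_pred, y_true):
    # phi is a feature extractor (e.g. the activations of some VGG layer);
    # the loss is the mean squared distance between the two feature maps
    f_pred, f_true = phi(y_pred), phi(y_true)
    return float(np.mean((f_pred - f_true) ** 2))
```

In practice `phi` would run the images through a frozen VGG up to a chosen layer; here any array-to-array function works, which is what makes the loss easy to plug into different tasks.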
Applications in NLP include learning from a large set of labeled reviews, and transferring to a model to analyze reviews for a new product.
Some code adapted from Francois Chollet:
https://blog.keras.io/building-powerful-image-classification-models-using-very-little-data.html
We will create a classifier to detect if a given image contains a shark xor a dolphin.
This classifier will be trained on features extracted from VGG16. We will then "fine tune" the last convolutional block of VGG, with our classifier on top using data augmentation.
Our training data consists of only 2000 images (compared to ImageNet's 13M images), which we will see can still give us an impressive accuracy.
Data: https://drive.google.com/open?id=0B8UdlI1bRN65YW9QTW12VXFDVFk
Fully connected classifier weights: https://drive.google.com/open?id=0B8UdlI1bRN65ck9vNmRncU1XbGM
Extracted train features: https://drive.google.com/open?id=0B8UdlI1bRN65aklWbnNBWkdyb1k
Extracted validation features: https://drive.google.com/open?id=0B8UdlI1bRN65N1VpMnVpcnlHRDQ
%matplotlib inline
import numpy as np
from keras.preprocessing.image import ImageDataGenerator
import matplotlib.pyplot as plt
# dimensions of our images.
img_width, img_height = 150, 150
test_dir = '../data/cats_dogs'
test_datagen = ImageDataGenerator(rescale=1./255)
plot_generator = test_datagen.flow_from_directory(
    test_dir,
    target_size=(img_width, img_height),
    batch_size=1,
    class_mode='binary')
num_images = 72
batch_size = 20
images = []
titles = []
i = 1
for inputs, labels in plot_generator:
    i += 1
    images.append(np.reshape(inputs, (img_width, img_height, 3)))
    # Some predictions for class 0 can have underflow error,
    # so check if == 1
    titles.append('shark' if labels[0] == 1 else 'dolphin')
    if i > num_images:
        break
def plot_gallery(images, titles, n_row=3, n_col=6):
    """Helper function to plot a gallery of images"""
    plt.figure(figsize=(1.7 * n_col, 2.3 * n_row))
    plt.subplots_adjust(bottom=0, left=.01, right=.99, top=.90, hspace=.35)
    for i in range(n_row * n_col):
        plt.subplot(n_row, n_col, i + 1)
        plt.imshow(images[i])
        plt.title(titles[i], size=12)
        plt.xticks(())
        plt.yticks(())
plot_gallery(images, titles, n_row=12, n_col=6)
Using TensorFlow backend.
Found 1000 images belonging to 2 classes.
%%time
from keras.models import Sequential, Model
from keras.layers import Dropout, Flatten, Dense, Input
from keras.applications import VGG16
from keras import optimizers
import pandas as pd
train_data_dir = '../data/train'
validation_data_dir = '../data/validation'
nb_train_samples = 2000
nb_validation_samples = 800
epochs = 50
batch_size = 16
#
# Save bottleneck features from VGG
#
# build the VGG16 network, leaving off the top classifier layer
# so we just get the features as output
print('Building VGG16...')
input_tensor = Input(shape=(img_width, img_height, 3))
base_model = VGG16(weights='imagenet', include_top=False,
                   input_tensor=input_tensor)
# Save training features
datagen = ImageDataGenerator(rescale=1./255)
generator = datagen.flow_from_directory(
    train_data_dir,
    target_size=(img_width, img_height),
    batch_size=batch_size,
    # class_mode=None because images are loaded in order (shuffle=False),
    # so we know the first 1000 are dolphins and the next 1000 are sharks
    class_mode=None,
    shuffle=False)
print('Saving bottleneck features (train)...')
bottleneck_features_train = base_model.predict_generator(
    generator, nb_train_samples // batch_size)
np.save('../features/bottleneck_features_train.npy', bottleneck_features_train)
# Save validation features
generator = datagen.flow_from_directory(
    validation_data_dir,
    target_size=(img_width, img_height),
    batch_size=batch_size,
    class_mode=None,
    shuffle=False)
print('Saving bottleneck features (validation)...')
bottleneck_features_validation = base_model.predict_generator(
    generator, nb_validation_samples // batch_size)
np.save('../features/bottleneck_features_validation.npy',
        bottleneck_features_validation)
Building VGG16...
Found 2000 images belonging to 2 classes.
Saving bottleneck features (train)...
Found 800 images belonging to 2 classes.
Saving bottleneck features (validation)...
CPU times: user 13.3 s, sys: 8.46 s, total: 21.8 s
Wall time: 3min 21s
%%time
# setup params and where to save features
epochs = 30
batch_size = 20
top_model_weights_path = '../models/dolphin_sharks_fc.h5'
# Load bottleneck features
train_data = np.load('../features/bottleneck_features_train.npy')
# images were loaded in order, so the first half of the samples
# are dolphins (label 0) and the second half are sharks (label 1)
train_labels = np.array(
    [0] * (nb_train_samples // 2) + [1] * (nb_train_samples // 2))
validation_data = np.load('../features/bottleneck_features_validation.npy')
# same ordering for the validation set
validation_labels = np.array(
    [0] * (nb_validation_samples // 2) + [1] * (nb_validation_samples // 2))
# Build model
print('Building top model...')
top_model = Sequential()
# flatten the output convolutions; some implementations instead
# perform an average pooling here to collapse the features down
top_model.add(Flatten(input_shape=train_data.shape[1:]))
# add two fully connected layers and some dropout
top_model.add(Dense(256, activation='relu'))
top_model.add(Dropout(0.5))
top_model.add(Dense(1, activation='sigmoid'))
# compile and add loss function; cross entropy is a good choice here
top_model.compile(optimizer='rmsprop',
                  loss='binary_crossentropy', metrics=['accuracy'])
# Train model
print('Training top model...')
history = top_model.fit(train_data, train_labels,
                        epochs=epochs,
                        batch_size=batch_size,
                        validation_data=(validation_data, validation_labels),
                        verbose=1)
top_model.save_weights(top_model_weights_path)
# notice that the model becomes accurate on the validation set VERY
# quickly -- so quickly that we need to be careful not to overtrain
Building top model...
Training top model...
Train on 2000 samples, validate on 800 samples
Epoch 1/30
2000/2000 [==============================] - 0s - loss: 0.5075 - acc: 0.8585 - val_loss: 0.2180 - val_acc: 0.9138
Epoch 2/30
2000/2000 [==============================] - 0s - loss: 0.2306 - acc: 0.9210 - val_loss: 0.2423 - val_acc: 0.9125
Epoch 3/30
2000/2000 [==============================] - 0s - loss: 0.1635 - acc: 0.9380 - val_loss: 0.4049 - val_acc: 0.8750
Epoch 4/30
2000/2000 [==============================] - 0s - loss: 0.1544 - acc: 0.9485 - val_loss: 0.2054 - val_acc: 0.9350
Epoch 5/30
2000/2000 [==============================] - 0s - loss: 0.1192 - acc: 0.9580 - val_loss: 0.2041 - val_acc: 0.9400
Epoch 6/30
2000/2000 [==============================] - 0s - loss: 0.1104 - acc: 0.9620 - val_loss: 0.1853 - val_acc: 0.9425
Epoch 7/30
2000/2000 [==============================] - 0s - loss: 0.1079 - acc: 0.9595 - val_loss: 0.2193 - val_acc: 0.9475
Epoch 8/30
2000/2000 [==============================] - 0s - loss: 0.0939 - acc: 0.9725 - val_loss: 0.2297 - val_acc: 0.9387
Epoch 9/30
2000/2000 [==============================] - 0s - loss: 0.0821 - acc: 0.9735 - val_loss: 0.3007 - val_acc: 0.9337
Epoch 10/30
2000/2000 [==============================] - 0s - loss: 0.0622 - acc: 0.9795 - val_loss: 0.2498 - val_acc: 0.9362
Epoch 11/30
2000/2000 [==============================] - 0s - loss: 0.0673 - acc: 0.9790 - val_loss: 0.2798 - val_acc: 0.9312
Epoch 12/30
2000/2000 [==============================] - 0s - loss: 0.0615 - acc: 0.9835 - val_loss: 0.3928 - val_acc: 0.9350
Epoch 13/30
2000/2000 [==============================] - 0s - loss: 0.0674 - acc: 0.9805 - val_loss: 0.3546 - val_acc: 0.9350
Epoch 14/30
2000/2000 [==============================] - 0s - loss: 0.0456 - acc: 0.9845 - val_loss: 0.2733 - val_acc: 0.9462
Epoch 15/30
2000/2000 [==============================] - 0s - loss: 0.0389 - acc: 0.9875 - val_loss: 0.2752 - val_acc: 0.9450
Epoch 16/30
2000/2000 [==============================] - 0s - loss: 0.0339 - acc: 0.9885 - val_loss: 0.3045 - val_acc: 0.9437
Epoch 17/30
2000/2000 [==============================] - 0s - loss: 0.0480 - acc: 0.9855 - val_loss: 0.3153 - val_acc: 0.9462
Epoch 18/30
2000/2000 [==============================] - 0s - loss: 0.0443 - acc: 0.9885 - val_loss: 0.4774 - val_acc: 0.9250
Epoch 19/30
2000/2000 [==============================] - 0s - loss: 0.0261 - acc: 0.9930 - val_loss: 0.3016 - val_acc: 0.9475
Epoch 20/30
2000/2000 [==============================] - 0s - loss: 0.0344 - acc: 0.9895 - val_loss: 0.3599 - val_acc: 0.9387
Epoch 21/30
2000/2000 [==============================] - 0s - loss: 0.0357 - acc: 0.9905 - val_loss: 0.3334 - val_acc: 0.9400
Epoch 22/30
2000/2000 [==============================] - 0s - loss: 0.0253 - acc: 0.9950 - val_loss: 0.3283 - val_acc: 0.9437
Epoch 23/30
2000/2000 [==============================] - 0s - loss: 0.0245 - acc: 0.9930 - val_loss: 0.3394 - val_acc: 0.9412
Epoch 24/30
2000/2000 [==============================] - 0s - loss: 0.0266 - acc: 0.9930 - val_loss: 0.3809 - val_acc: 0.9387
Epoch 25/30
2000/2000 [==============================] - 0s - loss: 0.0175 - acc: 0.9935 - val_loss: 0.3512 - val_acc: 0.9425
Epoch 26/30
2000/2000 [==============================] - 0s - loss: 0.0185 - acc: 0.9960 - val_loss: 0.4081 - val_acc: 0.9362
Epoch 27/30
2000/2000 [==============================] - 0s - loss: 0.0252 - acc: 0.9935 - val_loss: 0.4772 - val_acc: 0.9375
Epoch 28/30
2000/2000 [==============================] - 0s - loss: 0.0266 - acc: 0.9930 - val_loss: 0.4161 - val_acc: 0.9437
Epoch 29/30
2000/2000 [==============================] - 0s - loss: 0.0118 - acc: 0.9980 - val_loss: 0.4288 - val_acc: 0.9437
Epoch 30/30
2000/2000 [==============================] - 0s - loss: 0.0219 - acc: 0.9945 - val_loss: 0.4492 - val_acc: 0.9412
CPU times: user 23.2 s, sys: 2.63 s, total: 25.8 s
Wall time: 16.8 s
# Plot training and validation accuracy
def plot_training_validation_acc(history, smooth=False, smooth_factor=0.8):
    def smooth_curve(points, factor=smooth_factor):
        # exponential moving average over the raw per-epoch values
        smoothed_points = []
        for point in points:
            if smoothed_points:
                previous = smoothed_points[-1]
                smoothed_points.append(previous * factor + point * (1 - factor))
            else:
                smoothed_points.append(point)
        return smoothed_points

    acc = history.history['acc']
    val_acc = history.history['val_acc']
    loss = history.history['loss']
    val_loss = history.history['val_loss']
    if smooth:
        acc = smooth_curve(acc)
        val_acc = smooth_curve(val_acc)
        loss = smooth_curve(loss)
        val_loss = smooth_curve(val_loss)
    epochs = range(len(acc))
    plt.plot(epochs, acc, 'bo', label='Training acc')
    plt.plot(epochs, val_acc, 'b', label='Validation acc')
    plt.title('Training and validation accuracy')
    plt.legend()
    plt.figure()
    plt.plot(epochs, loss, 'bo', label='Training loss')
    plt.plot(epochs, val_loss, 'b', label='Validation loss')
    plt.title('Training and validation loss')
    plt.legend()
    plt.show()
plot_training_validation_acc(history)
We get pretty good validation accuracy, especially with such a small dataset, but there seems to be some overfitting.
To improve our classifier, we can place it on top of VGG, "unfreeze" the last convolutional block of VGG, then continue training with augmented samples from our dataset.
This will "fine tune" the last block in VGG, tweaking it to our domain, as well as continuing to train the classifier.
It is necessary to pre-train the classifier. If we were to place randomly initialized layers on top, large gradient updates would wreck the learned weights in the block we are fine tuning.
%%time
epochs = 20
batch_size = 16
#
# Fine tune top convolutional block
#
print('Building combined model...')
# note that it is necessary to start with a fully-trained
# classifier in order to successfully do fine-tuning
top_model.load_weights(top_model_weights_path)
# add the classifier on top of the convolutional base
model = Model(inputs=base_model.input,
              outputs=top_model(base_model.output))
# now fine tune the last convolutional block of VGG:
# freeze all layers before block5 (the block we are fine tuning)
set_trainable = False
for layer in base_model.layers:
    if layer.name == 'block5_conv1':
        set_trainable = True
    layer.trainable = set_trainable
# compile the model with an SGD/momentum optimizer
# and a very slow learning rate
model.compile(loss='binary_crossentropy',
              optimizer=optimizers.SGD(lr=1e-4, momentum=0.9),
              metrics=['accuracy'])
# prepare data augmentation configuration: since we are now passing
# raw images through the network, we can add a bit of data
# augmentation to help with overfitting
train_datagen = ImageDataGenerator(rescale=1./255,
                                   rotation_range=40,
                                   width_shift_range=0.2,
                                   height_shift_range=0.2,
                                   shear_range=0.2,
                                   zoom_range=0.2,
                                   horizontal_flip=True,
                                   fill_mode='nearest')
test_datagen = ImageDataGenerator(rescale=1./255)
train_generator = train_datagen.flow_from_directory(
    train_data_dir,
    target_size=(img_height, img_width),
    batch_size=batch_size,
    class_mode='binary')
validation_generator = test_datagen.flow_from_directory(
    validation_data_dir,
    target_size=(img_height, img_width),
    batch_size=batch_size,
    class_mode='binary')
# fine-tune the model
print('Training combined model...')
history = model.fit_generator(
    train_generator,
    steps_per_epoch=nb_train_samples // batch_size,
    epochs=epochs,
    validation_data=validation_generator,
    validation_steps=nb_validation_samples // batch_size,
    verbose=1)
Building combined model...
Found 2000 images belonging to 2 classes.
Found 800 images belonging to 2 classes.
Training combined model...
Epoch 1/20
125/125 [==============================] - 15s - loss: 0.0873 - acc: 0.9710 - val_loss: 0.1334 - val_acc: 0.9587
Epoch 2/20
125/125 [==============================] - 13s - loss: 0.0925 - acc: 0.9665 - val_loss: 0.1252 - val_acc: 0.9600
Epoch 3/20
125/125 [==============================] - 14s - loss: 0.0904 - acc: 0.9675 - val_loss: 0.1584 - val_acc: 0.9513
Epoch 4/20
125/125 [==============================] - 14s - loss: 0.0582 - acc: 0.9775 - val_loss: 0.1525 - val_acc: 0.9613
Epoch 5/20
125/125 [==============================] - 14s - loss: 0.0648 - acc: 0.9755 - val_loss: 0.1403 - val_acc: 0.9525
Epoch 6/20
125/125 [==============================] - 14s - loss: 0.0550 - acc: 0.9825 - val_loss: 0.1005 - val_acc: 0.9738
Epoch 7/20
125/125 [==============================] - 14s - loss: 0.0661 - acc: 0.9780 - val_loss: 0.1137 - val_acc: 0.9637
Epoch 8/20
125/125 [==============================] - 14s - loss: 0.0557 - acc: 0.9810 - val_loss: 0.2089 - val_acc: 0.9500
Epoch 9/20
125/125 [==============================] - 13s - loss: 0.0755 - acc: 0.9725 - val_loss: 0.1359 - val_acc: 0.9587
Epoch 10/20
125/125 [==============================] - 14s - loss: 0.0691 - acc: 0.9720 - val_loss: 0.1486 - val_acc: 0.9513
Epoch 11/20
125/125 [==============================] - 14s - loss: 0.0580 - acc: 0.9790 - val_loss: 0.1988 - val_acc: 0.9537
Epoch 12/20
125/125 [==============================] - 14s - loss: 0.0674 - acc: 0.9780 - val_loss: 0.1943 - val_acc: 0.9475
Epoch 13/20
125/125 [==============================] - 14s - loss: 0.0528 - acc: 0.9830 - val_loss: 0.1651 - val_acc: 0.9587
Epoch 14/20
125/125 [==============================] - 14s - loss: 0.0446 - acc: 0.9830 - val_loss: 0.1854 - val_acc: 0.9587
Epoch 15/20
125/125 [==============================] - 14s - loss: 0.0378 - acc: 0.9870 - val_loss: 0.1806 - val_acc: 0.9575
Epoch 16/20
125/125 [==============================] - 14s - loss: 0.0593 - acc: 0.9800 - val_loss: 0.1185 - val_acc: 0.9550
Epoch 17/20
125/125 [==============================] - 14s - loss: 0.0377 - acc: 0.9870 - val_loss: 0.1553 - val_acc: 0.9575
Epoch 18/20
125/125 [==============================] - 14s - loss: 0.0457 - acc: 0.9830 - val_loss: 0.1406 - val_acc: 0.9550
Epoch 19/20
125/125 [==============================] - 14s - loss: 0.0441 - acc: 0.9855 - val_loss: 0.1225 - val_acc: 0.9637
Epoch 20/20
125/125 [==============================] - 14s - loss: 0.0407 - acc: 0.9880 - val_loss: 0.1565 - val_acc: 0.9537
CPU times: user 5min 19s, sys: 17.2 s, total: 5min 36s
Wall time: 4min 43s
plot_training_validation_acc(history, smooth=True)
test_dir = '../data/test'
test_generator = test_datagen.flow_from_directory(
    test_dir,
    target_size=(img_width, img_height),
    batch_size=20,
    class_mode='binary')
test_loss, test_acc = model.evaluate_generator(test_generator, steps=50)
print('test acc:', test_acc)
Found 1000 images belonging to 2 classes.
test acc: 0.968999992609
We get an increase in validation accuracy, and there seems to be less overfitting now.
We can plot selected images from the test set and see what the classifier predicted (true label is in parentheses).
plot_generator = test_datagen.flow_from_directory(
    test_dir,
    target_size=(img_width, img_height),
    batch_size=1,
    class_mode='binary')
num_images = 72
batch_size = 20
images = []
titles = []
predictions = []
i = 1
for inputs, labels in plot_generator:
    i += 1
    images.append(np.reshape(inputs, (img_width, img_height, 3)))
    # compare against 0.5 to be robust to floating point label values
    titles.append('shark' if labels[0] > .5 else 'dolphin')
    y_hat = model.predict(inputs)
    predictions.append('shark' if y_hat > .5 else 'dolphin')
    if i > num_images:
        break
for i in range(num_images):
    predictions[i] += ' (%s)' % titles[i]
plot_gallery(images, predictions, n_row=12, n_col=6)
Found 1000 images belonging to 2 classes.
By fine tuning the last convolutional block, we were able to achieve a slight increase in accuracy. If we had used data augmentation from the beginning, training our classifier directly on top of VGG instead of saving out the features, training would have taken longer, but we could have gained a further slight increase in accuracy.
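That alternative can be sketched as follows: put the same classifier directly on top of the VGG16 base, freeze the base, and train end-to-end on the augmented generators from the start. This is a sketch under assumptions, not code that was run for these notes; it uses the `tensorflow.keras` import path rather than standalone `keras`, and `weights=None` stands in for `weights='imagenet'` purely to avoid a download.

```python
from tensorflow.keras import Input, Model, layers
from tensorflow.keras.applications import VGG16

inputs = Input(shape=(150, 150, 3))
# weights=None only to keep the sketch self-contained;
# use weights='imagenet' in practice
base = VGG16(weights=None, include_top=False, input_tensor=inputs)
base.trainable = False  # freeze the entire convolutional base for the first phase
x = layers.Flatten()(base.output)
x = layers.Dense(256, activation='relu')(x)
x = layers.Dropout(0.5)(x)
outputs = layers.Dense(1, activation='sigmoid')(x)
model = Model(inputs, outputs)
model.compile(optimizer='rmsprop',
              loss='binary_crossentropy', metrics=['accuracy'])
# first phase: model.fit(...) on the augmented training generator;
# then unfreeze block5 and recompile with a small SGD learning rate to fine tune
```

Because every augmented image must pass through the full convolutional base on every epoch, this is much slower than training on saved bottleneck features, which is why the notebook above extracts features first.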
import IPython
url = 'https://storage.googleapis.com/tfjs-examples/webcam-transfer-learning/dist/index.html'
iframe = '<iframe src="%s" width=900 height=700></iframe>' % url
IPython.display.HTML(iframe)
Info:
https://machinelearningmastery.com/transfer-learning-for-deep-learning/
http://ruder.io/transfer-learning
https://www.cse.ust.hk/~qyang/Docs/2009/tkde_transfer_learning.pdf
Image graph: Suzuki, Masahiro, Sato, Haruhiko, Oyama, Satoshi, & Kurihara, Masahito (2014). Transfer learning based on the observation probability of each attribute.
Pacman demo: https://github.com/tensorflow/tfjs-examples/tree/master/webcam-transfer-learning